#!/usr/bin/env python3
# Copyright 2019-2022 Alibaba Group Holding Limited.
# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause

from typing import Dict
import sh
import os
import subprocess
import sys
from glob import glob

class TestVarUniformity:
    def setup_class(self):
        print("Var uniformity test")
        self.global_name = [
            "/proc/sys/kernel/sched_child_runs_first",
            "/proc/sys/kernel/sched_min_granularity_ns",
            "/proc/sys/kernel/sched_latency_ns",
            "/proc/sys/kernel/sched_wakeup_granularity_ns",
            "/proc/sys/kernel/sched_tunable_scaling",
            "/proc/sys/kernel/sched_migration_cost_ns",
            "/proc/sys/kernel/sched_nr_migrate",
            "/proc/sys/kernel/sched_schedstats",
            "/proc/sys/kernel/numa_balancing_scan_delay_ms",
            "/proc/sys/kernel/numa_balancing_scan_period_min_ms",
            "/proc/sys/kernel/numa_balancing_scan_period_max_ms",
            "/proc/sys/kernel/numa_balancing_scan_size_mb",
            "/proc/sys/kernel/numa_balancing",
            "/proc/sys/kernel/sched_rt_period_us",
            "/proc/sys/kernel/sched_rr_timeslice_ms",
            "/proc/sys/kernel/sched_autogroup_enabled",
            "/proc/sys/kernel/sched_cfs_bandwidth_slice_us",
            "/sys/kernel/debug/sched_debug",
        ]
        
    def before_change(self):
        self.orig_data = {}
        self.record_data(self.orig_data)
        self.load_scheduler()
        self.data_after_load = {}
        self.record_data(self.data_after_load)
    
    def record_data(self, dict: Dict):
        for item in self.global_name:
            if not os.path.exists(item):
                continue
            dict[item] = str(sh.cat(item)).strip()

    def load_scheduler(self):
        scheduler_rpm = glob(os.path.join('/tmp/work', 'scheduler*.rpm'))
        if len(scheduler_rpm) != 1:
            print("Please check your scheduler rpm");
            self.teardown_class()
            sys.exit(1)
        scheduler_rpm = scheduler_rpm[0]
        sh.rpm('-ivh', scheduler_rpm)

    def after_change_unload(self):
        self.modify_data()
        self.data_after_modified = {}
        self.record_data(self.data_after_modified)
        sh.rpm('-e', 'scheduler-xxx')
        self.data_after_unload = {}
        self.record_data(self.data_after_unload)

    def modify_data(self):
        def reverse(ch):
            if ch.isdigit():
                return '1' if ch == '0' else str(int(ch) - 1)
            return 'N' if ch == 'Y' else 'Y'

        for k, v in self.orig_data.items():
            sh.echo(reverse(v), _out=k)

    def teardown_class(self):
        tmp = subprocess.Popen("lsmod | grep scheduler", shell=True, stdout=subprocess.PIPE)
        if tmp.stdout.read() != b'':
            sh.rpm('-e', 'scheduler-xxx')
        for k, v in self.orig_data.items():
            sh.echo(v, _out=k)

    def test_data_uniformity(self):
        self.before_change()
        if not self.orig_data == self.data_after_load:
            self.error_handler()
        self.after_change_unload()
        if not self.data_after_modified == self.data_after_unload:
            self.error_handler()

    def error_handler(self):
        self.teardown_class()
        sys.exit(1)

if __name__ == '__main__':
    # Stand-alone entry point: run the single test without a test runner.
    unit_test = TestVarUniformity()
    unit_test.setup_class()
    try:
        unit_test.test_data_uniformity()
    finally:
        # Restore the tunables and remove the package even when the test
        # aborts with an unexpected exception (e.g. a failing rpm call).
        # On the sys.exit path teardown has already run once; running it
        # again only re-checks lsmod and rewrites the same original values.
        unit_test.teardown_class()
